# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import json
import logging
from pathlib import Path

from ..datasets.canary_exploit.verify_response import verify_responses

from .benchmark import Benchmark
from .query_llm import query_llm_to_generate_responses


class CanaryExploitBenchmark(Benchmark):
    """
    Benchmark implementation for the Canary Exploit evaluation.

    Extends the Benchmark base class with the logic to collect LLM
    responses for the canary-exploit prompt dataset and to grade those
    responses into judge and statistics files.
    """

    @classmethod
    def return_kind(cls) -> list[str]:
        """Return the benchmark-kind identifiers handled by this class."""
        return ["canary-exploit"]

    def query_llm_to_generate_responses(
        self, prompt_path: Path, run_llm_in_parallel: int = 1
    ) -> None:
        """
        Query every LLM under test with the prompt dataset and persist the
        combined responses as a JSON array at ``self.response_path``.
        """
        # Constrain each model's output to a JSON object with a single
        # required string field named "answer".
        guided_decode_json_schema = '{"type": "object", "properties": {"answer": {"type": "string"}}, "required": ["answer"]}'
        collected = []
        for model_under_test in self.llms_under_test:
            logging.info(f"Querying {model_under_test.model}")
            responses = query_llm_to_generate_responses(
                {"initial_query_llm": model_under_test},
                prompt_path,
                run_llm_in_parallel,
                num_test_cases=self.num_test_cases,
                pass_k=self.pass_k,
                guided_decode_json_schema=guided_decode_json_schema,
            )
            collected += responses

        self.response_path.write_text(json.dumps(collected, indent=4))

    async def run(
        self,
        num_test_cases: int = 0,
        run_llm_in_parallel: int = 16,
        should_cleanup_after_eval: bool = True,
    ) -> None:
        """
        Grade the previously collected responses, writing judge output and
        statistics to ``self.judge_response_path`` and ``self.stat_path``.

        NOTE(review): the parameters are accepted for signature compatibility
        with ``Benchmark.run`` but are not used by this benchmark — confirm
        against the base class before changing.
        """
        verify_responses(self.response_path, self.judge_response_path, self.stat_path)
